
Distributed Neighborhood Attention S2#161

Open
azrael417 wants to merge 18 commits into main from tkurth/distributed-neighborhood-attention

Conversation

@azrael417
Collaborator

This PR adds distributed Neighborhood Attention S2 support and fixes some issues in the existing attention kernel.

  • The existing serial attention kernel preallocated an output tensor that was too large when attention-based downsampling was used. This is fixed here.
  • The existing serial attention kernel does not produce the correct v gradient when used for upsampling. We will fix that in a follow-up.
  • This PR adds distributed neighborhood attention along with new tests for the feature. The distributed kernel does not support up- or downsampling yet.
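To make the first bullet concrete, here is a minimal sketch of the shape issue. All names (`nlat_in`, `nlat_out`, etc.) are illustrative, not the actual torch-harmonics API: with attention-based downsampling the output lives on the coarser output grid, so the preallocated tensor must use the output grid's dimensions rather than the input's.

```python
# Hypothetical sketch of the shape fix described above; the function and
# parameter names are illustrative, not the torch-harmonics API.

def attention_output_shape(batch, channels, nlat_in, nlon_in, nlat_out, nlon_out):
    """Return the shape the attention output tensor should be preallocated with.

    For attention-based downsampling (nlat_out < nlat_in), the output lives
    on the output grid; allocating (nlat_in, nlon_in) instead over-allocates,
    which is the bug the first bullet fixes.
    """
    return (batch, channels, nlat_out, nlon_out)
```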

@azrael417 azrael417 requested review from bonevbs March 30, 2026 16:10
@azrael417 azrael417 self-assigned this Mar 30, 2026
@azrael417 azrael417 force-pushed the tkurth/distributed-neighborhood-attention branch 3 times, most recently from b75d634 to 4d434d6 Compare April 6, 2026 07:45
@bonevbs
Collaborator

bonevbs commented Apr 15, 2026

Please bump the version number to 0.9.1a and start the Changelog for v0.9.1.

@azrael417 azrael417 force-pushed the tkurth/distributed-neighborhood-attention branch from da78cae to 46fb343 Compare April 16, 2026 08:23
@azrael417 azrael417 force-pushed the tkurth/distributed-neighborhood-attention branch from 46fb343 to b9a386e Compare April 16, 2026 10:52
@azrael417 azrael417 marked this pull request as ready for review April 16, 2026 12:40
Collaborator

@bonevbs bonevbs left a comment


Please bump the version number to 0.9.1a1 and start the Changelog for v0.9.1

Also, some minor comments below.

Comment thread torch_harmonics/attention/csrc/attention_cuda_bwd.cu Outdated
Comment thread torch_harmonics/distributed/primitives.py Outdated
def distributed_transpose_polar(input_, dims_, shapes_):
    return _DistributeTransposePolar.apply(input_, dims_, shapes_)

@torch.compiler.disable()
Collaborator


why do we need those?

Collaborator Author


This is important so that the torch.compile graph will properly break here.

@bonevbs
Collaborator

bonevbs commented Apr 16, 2026

@rietmann-nv can you also have a look at the new CUDA kernels for distributed spherical attention?

@bonevbs bonevbs requested a review from apaaris April 16, 2026 12:59